% We wish to track mfr patterns that occur during training and see if they
% crop up later on in sequence completion when we don't know the time frame
% of the switch between patterns.
%
% It can be used in a free recall situation in which you don't know what the
% network is doing because you are not presenting a pattern.
%
% Analyze the data from the spiking network for some task consisting of
% trials of equal duration.
%
%
%  function [rs sig_cells] = analyzeSequence( area_name, type_name, phase_start_msec, phase_end_msec, train_start_time_msecs, train_end_time_msecs, start_time_msecs, end_time_msecs, significance_level, save_seconds)
% 
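%  area_name, type_name: the neural area and cell type to analyze; the pair
%   must be listed in groups.dat.
%  phase_start_msec, phase_end_msec:   the start and
%		end times (inclusive) within each training trial over which spikes are
%		counted to build the templates.  Start counting at 0 for the first msec.
%  train_start_time_msecs, train_end_time_msecs: trials whose start time t
%   (first column of patternhist.dat) satisfies
%   train_start_time_msecs <= t < train_end_time_msecs are used for training.
%  start_time_msecs, end_time_msecs: start and end of the testing period that
%   is searched for matches to the training templates.
%  significance_level: alpha for the Wilcoxon rank-sum tests (default 0.05).
%
%  save_seconds = the number of seconds saved in one spikes<N>.dat file; must
%   be an integer (default 1).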
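%  RETURNS:
%	rs(neuron, pattern i, pattern j): 1 when the neuron's spike counts during
%	the task phase differ significantly between patterns i and j (Wilcoxon
%	rank-sum test at significance_level) AND its mean count is larger for
%	pattern i; 0 otherwise.
%	sig_cells(i).list: 0-based ids of the neurons for which rs is 1 against
%	every other pattern j.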
%
%  EXAMPLE:
%
%
% analyzeSequence('S1','p23',200,999,4000,12000,24000,40000)
%
% Where the last 2 parameters are now the testing time for finding the pattern match.  
% The first parameters are as before and require a known training period with controlled timing.  
% The spike data during this time are used to find the patterns, while the testing period is 
% searched in 50 msec windows for matches to these firing rate patterns.  The main result is 
% in figure 101 which shows the match to each sequence.  The matching criterion is a 
% normalized dot product, so perfect matches are near 1, while orthogonal patterns yield a 0.  
% This works well for labelled line coding with high firing rates.  If the firing rate is low, 
% you might need to increase the WINDOW variable.
%
%  This example will analyze mean firing rate patterns for each pattern presented to the network
%  between 4000 msec and 12000 msec into the
%  simulation. During each trial, the spikes will be analyzed only between
%  relative time 200 and 999 msec to build a mfr template for each pattern.
%  Then during 24000 to 40000msec, the program will
%  go back and try to find matches to the patterns with a spike bin width
%  of 50 msec.
%  The file patternhist.dat looks like this, where each line is a trial:
%  the first entry is the start time of the trial in msec and the second
%  is the pattern number in the range [0, N-1], where N is the number
%  of patterns used in the experiments (no gaps).
% 
% 
% 0 0
% 250 1
% 500 2
% 750 3
% 1000 0
% 1250 1
% 1500 2
% 1750 3
% 2000 0
% 2250 1
% 2500 2
% 2750 3
% ...
% 
%
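% If area_name == 'mfr', then the rest of the params are unnecessary; it
% will be assumed that the mfr structure is precalculated and rows and cols are set.
% NOTE(review): analyzeSequence itself does not check for area_name == 'mfr';
% it always clears the global mfr and recomputes it from the spike files.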
% 
% mfr is defined as follows: mfr(pattern).spikes(neuron,trial) =
%   the spikes per second that 'neuron' fires in response to 'pattern' on
%   'trial'.  Each pattern can have a different number of pattern
%   presentations.

function [rs sig_cells] = analyzeSequence( area_name, type_name, phase_start_msec, phase_end_msec ,train_start_time_msecs, train_end_time_msecs, start_time_msecs, end_time_msecs,significance_level, save_seconds)
% ANALYZESEQUENCE  Build mean-firing-rate templates for each pattern from a
% training period, then scan a test period in WINDOW-msec bins for matches.
%
% Inputs:
%   area_name, type_name    - area / cell type to analyze; must be listed in
%                             groups.dat (read by the read_groups script).
%   phase_start_msec,
%   phase_end_msec          - window inside each training trial (0-based msec
%                             relative to the trial start, both inclusive)
%                             over which spikes are counted.
%   train_start_time_msecs,
%   train_end_time_msecs    - trials starting at t with
%                             train_start <= t < train_end are training trials.
%   start_time_msecs,
%   end_time_msecs          - test period; windows start at
%                             start_time_msecs:WINDOW:end_time_msecs.
%   significance_level      - alpha for the rank-sum tests (default 0.05).
%   save_seconds            - seconds per spikes<N>.dat file (default 1).
%
% Outputs:
%   rs(n,i,j)         - true when neuron n's counts for pattern i differ
%                       significantly from those for pattern j (Wilcoxon
%                       rank-sum) AND its mean count for i is larger.
%   sig_cells(i).list - 0-based global ids of the neurons with rs(n,i,j)
%                       true for every j ~= i.
%   Both are [] if the requested area/type is not found.
%
% Side effects: sets the globals mfr, rows, cols and PSTATES, prints
% summaries, and draws figures numbered from 100 upwards.

global mfr rows cols PSTATES

% Define outputs up front so the early return below never leaves them unset.
rs = [];
sig_cells = [];

if nargin < 9
    significance_level = 0.05;
end
if nargin < 10
    save_seconds = 1;
end

mfr = [];

% Each row: [trial start time (msec), 0-based pattern number].
patternhist = load('patternhist.dat');

% The spike matrix S built below stores msec m in column m+1.  These are the
% column offsets (relative to the trial start) of the first and last msec of
% the analysis phase.  phase_start_msec/phase_end_msec are left untouched so
% that plot titles report the window the caller asked for.
phase_first_col = phase_start_msec + 1;
phase_last_col  = phase_end_msec + 1;
phase_duration_msec = phase_end_msec - phase_start_msec + 1;

% Script; expected to define areas, celltype, neuron_id, la and lb from groups.dat.
read_groups;

index = [];
for g = 1:length(areas)
    if strcmp(deblank(areas{g}), area_name) && strcmp(type_name, deblank(celltype{g}))
        index = g;
        break;
    end
end
if isempty(index)
    disp([area_name ' ' type_name ' is not in groups.dat.']);
    return;
end

% neuron_id holds 0-based [first last] ids per group; S rows are id+1.
first = neuron_id(index,1) + 1;
last  = neuron_id(index,2) + 1;
num_neurons = max(neuron_id(:,2)) + 1;

rows = la(index);
cols = lb(index);
N = rows*cols;

num_patterns = max(patternhist(:,2)) + 1;
count = zeros(1, num_patterns);

fprintf('Area %s %s: neurons %d-%d (0-based), %d x %d grid, %d patterns.\n', ...
    area_name, type_name, first-1, last-1, rows, cols, num_patterns);

% Pre-create every pattern so a pattern that never occurs during training
% still has an (N x 0) spike matrix instead of a missing struct element.
mfr = struct('spikes', repmat({zeros(N,0)}, 1, num_patterns));

% ---- Training: spike counts per neuron for every complete presentation.
filename = '';
S = [];
trial_indices_to_analyze = find(patternhist(:,1) >= train_start_time_msecs & patternhist(:,1) < train_end_time_msecs);

for trial = trial_indices_to_analyze'
    t = patternhist(trial, 1);
    first_col = t + phase_first_col;
    last_col  = t + phase_last_col;

    [S, filename] = load_spike_file_if_needed(first_col, filename, S, save_seconds, num_neurons);

    % If this pattern presentation is split across files, ignore it.
    if ~strcmp(spike_filename(last_col, save_seconds), filename)
        continue;
    end

    p = patternhist(trial, 2) + 1;
    count(p) = count(p) + 1;
    mfr(p).spikes(:, count(p)) = full(sum(S(first:last, first_col:last_col), 2));
end

disp('Complete training presentations per pattern:');
disp(count);

% ---- Mean response maps and normalized templates.
k = 0;
figure(100+k); k = k + 1;

disp('Mean firing rate of cells to each pattern');
% Maps are displayed transposed (cols x rows), so preallocate that shape.
meanresponses = zeros(cols, rows, num_patterns);
mfrtemplates = zeros(N, num_patterns);
side = ceil(sqrt(num_patterns));
for p = 1:num_patterns
    subplot(side, side, p);

    if count(p) > 0
        mean_counts = mean(mfr(p).spikes, 2);
    else
        warning('analyzeSequence:noTrainingTrials', ...
            'Pattern %d has no complete training presentations; using an all-zero template.', p);
        mean_counts = zeros(N, 1);
    end

    % Spike counts over the phase window -> spikes per second.
    meanresponses(:,:,p) = transpose(reshape(mean_counts*1000/phase_duration_msec, rows, cols));
    imagesc(meanresponses(:,:,p));
    mfrtemplates(:,p) = mean_counts;

    colorbar;
    set(gca, 'XTickLabel', {});
end
% Neurons averaging more than one spike per phase window form each pattern's "shape".
mfrshapes = mfrtemplates > 1;
mfrtemplates = normalize_columns(mfrtemplates);

% ---- Testing: slide a WINDOW-msec bin over the test period and score the
% normalized dot product of each bin against every template.
WINDOW = 50;  % msec; should evenly divide save_seconds*1000 so bins do not straddle files.

window_starts = start_time_msecs:WINDOW:end_time_msecs;
PSTATES = zeros(num_patterns, numel(window_starts));
filename = '';
for w = 1:numel(window_starts)
    t = window_starts(w);
    first_col = t + 1;        % column of msec t
    last_col  = t + WINDOW;   % column of msec t+WINDOW-1

    [S, filename] = load_spike_file_if_needed(first_col, filename, S, save_seconds, num_neurons);

    if ~strcmp(spike_filename(last_col, save_seconds), filename)
        disp('warning: ignoring data because the bin window crossed file boundaries.');
        disp(t)
        disp(t+WINDOW-1)
        continue;   % PSTATES(:,w) stays 0, keeping columns aligned with window_starts
    end

    timebin = full(sum(S(first:last, first_col:last_col), 2));

    % Keep only the neurons in each pattern's template shape, then normalize;
    % all-zero columns stay zero.
    timebin = normalize_columns(mfrshapes .* repmat(timebin, 1, num_patterns));

    PSTATES(:, w) = sum(mfrtemplates .* timebin, 1)';
    % Dot products of unit vectors cannot exceed 1 beyond rounding error.
    if any(PSTATES(:, w) > 1 + 1e-9)
        disp('How is this possible?');
    end
end

figure(100+k); k = k + 1;
plot(window_starts, PSTATES');
xlabel('T (msec)','FontSize',12)
ylabel('Match strength','FontSize',12)
title('Match strength for each pattern over time','FontSize',14)

figure(100+k); k = k + 1;
if ~isempty(window_starts)
    % Label each column by the end time of its window, in seconds.
    window_end_sec = (window_starts([1 end]) + WINDOW)/1000;
    imagesc(window_end_sec, [1 num_patterns], PSTATES);
end
colormap(gray)
ylabel('Pattern number','FontSize',12)
title('Match strength for each pattern over time','FontSize',14)
xlabel('T (sec)','FontSize',12)

figure(100+k); k = k + 1;
imagesc(mean(meanresponses,3));
colorbar
title('Mean firing rate for all neurons across all patterns','FontSize',14);
drawnow

figure(100+k); k = k + 1;
sortedmfr = sort(meanresponses, 3, 'descend');
sortedmfr = reshape(sortedmfr, rows*cols, num_patterns);
errorbar(mean(sortedmfr), std(sortedmfr));
title('Grand tuning curve over all neurons (mfr +/- std.)','FontSize',14);
ylabel('Mean firing rate (sp/sec)','FontSize',12);
xlabel('pattern number sorted from best to worst','FontSize',12);
drawnow

figure(100+k); k = k + 1;
imagesc(sortedmfr);

figure(100+k); k = k + 1;
b = bitsinfo(sortedmfr, 2.5:5:(max(sortedmfr(:,1))+2.49)); % bins of ~ 5 sp/sec.
hist(b, 20);
title('histogram of bits of information conveyed by neurons','FontSize',14);
drawnow

figure(100+k); k = k + 1;
[~, best_indices] = max(meanresponses, [], 3);
imagesc(best_indices, [0 num_patterns]); % start scale at 0 for comparison with significance plot
colorbar
axis off
title('Optimal pattern for each cell','FontSize',14);
drawnow

figure(100+k); k = k + 1;
hist(best_indices(:), num_patterns);
title('Histogram of neurons responding best to each pattern','FontSize',14);
drawnow

% ---- See if the patterns can be classified from the network response vector.
train_classifier = true;
if train_classifier
    samples = [];
    cats = [];
    testing_samples = [];
    testing_cats = [];

    % First 2/3 of each pattern's presentations train, the rest test.
    for p = 1:num_patterns
        n_train = floor(2/3*count(p));
        n_test = count(p) - n_train;

        samples = [samples; mfr(p).spikes(:, 1:n_train)'];
        cats = [cats ones(1, n_train)*p];

        testing_samples = [testing_samples; mfr(p).spikes(:, n_train+1:end)'];
        testing_cats = [testing_cats ones(1, n_test)*p];
    end

    % Linear layer designed by newlind; one output row per pattern.
    T = full(ind2vec(cats)); % turn category indices into target vectors
    linnet = newlind(samples', T);

    [~, Yc] = max(sim(linnet, samples'));
    percent_correct_train = sum(Yc == cats)/length(Yc)*100
    cMat_train = confusionmat(cats, Yc) % the confusion matrix

    if isempty(testing_cats)
        disp('No held-out presentations; skipping test-set evaluation.');
    else
        [~, Yc] = max(sim(linnet, testing_samples'));
        percent_correct_test = sum(Yc == testing_cats)/length(Yc)*100
        cMat_test = confusionmat(testing_cats, Yc) % the confusion matrix
    end
end

% ---- Significantly preferential responses to each pattern.
rs = false(N, num_patterns, num_patterns);
num_significant = zeros(num_patterns);
for i = 1:num_patterns
    for j = 1:num_patterns
        % rs(:,i,i) is always false (identical samples), so skip i == j.
        if i ~= j && count(i) > 1 && count(j) > 1
            mean_i = mean(mfr(i).spikes, 2);
            mean_j = mean(mfr(j).spikes, 2);
            for n = 1:N
                [~, significant] = ranksum(mfr(i).spikes(n,:), mfr(j).spikes(n,:), 'alpha', significance_level);
                rs(n,i,j) = significant && mean_i(n) > mean_j(n);
            end
            num_significant(i,j) = sum(rs(:,i,j));
        end
    end
    disp('.');
end

disp('number of cells which distinguish pattern "row" from pattern "col"');
disp(num_significant);

discriminators = zeros(1, num_patterns);
merged_discriminator_map = zeros(N, 1);
sig_cells = struct('list', cell(1, num_patterns));

for i = 1:num_patterns
    % Cells that beat pattern i's every rival significantly.
    others = [1:i-1, i+1:num_patterns];
    temp = all(rs(:, i, others), 3);

    % A cell can win for at most one pattern, so the map holds that pattern's number.
    merged_discriminator_map = merged_discriminator_map + i*temp;
    discriminators(i) = sum(temp);

    figure(100+k); k = k + 1;
    imagesc(transpose(reshape(temp, rows, cols)))
    colormap(gray);
    title({['Cells in area ' area_name type_name ' that distinguish pattern ' num2str(i)],[ ' from ALL others statistically (p< ' num2str(significance_level) ', Wilcoxon RS); t=' num2str(phase_start_msec) ' ' num2str(phase_end_msec) ' and respond to maximally to it.' ]},'FontSize',14);

    disp('Significant cell numbers (start counting at 0 as in C): ');
    % Index within the area -> 0-based global id (first is 1-based here).
    sig_cells(i).list = (find(temp) + first - 2)';
    disp(sig_cells(i).list);
end

figure(100+k); k = k + 1;
bar(discriminators);
title({['Histogram of cells in area ' area_name type_name ' that distinguish one pattern from ALL others'],['statistically (p< ' num2str(significance_level) ', Wilcoxon RS); t=' num2str(phase_start_msec) ' ' num2str(phase_end_msec) ' and respond to maximally to it.' ]},'FontSize',14);

figure(100+k); k = k + 1;
imagesc(transpose(reshape(merged_discriminator_map, rows, cols)))
title({['Cells in area ' area_name type_name ' that distinguish one pattern from ALL others'],['statistically (p< ' num2str(significance_level) ', Wilcoxon RS); t=' num2str(phase_start_msec) ' ' num2str(phase_end_msec) ' and respond to maximally to it.' ]},'FontSize',14);
axis off
colorbar

% Dump the per-presentation counts of every discriminating unit.
for unit = find(merged_discriminator_map')
    disp(['Responses for unit ' num2str(unit)]);
    for p = 1:num_patterns
        disp([p mfr(p).spikes(unit,:)]);
    end
end
end


function name = spike_filename(col, save_seconds)
% Name of the spikes<T>.dat file holding spike-matrix column col (msec col-1);
% T is the end of that file's save_seconds span, in msec.
name = ['spikes' num2str(file_end_msec(col, save_seconds), 15) '.dat'];
end


function t_end = file_end_msec(col, save_seconds)
% End (msec) of the save_seconds-long file span that contains column col.
t_end = ceil(col/(save_seconds*1000))*save_seconds*1000;
end


function [S, filename] = load_spike_file_if_needed(col, filename, S, save_seconds, num_neurons)
% Return the sparse spike matrix for the file containing column col, loading
% it only when it differs from the currently loaded filename.
% Each file row is [time_msec neuron_id] (both 0-based); S(id+1, time+1)
% counts the spikes.
wanted = spike_filename(col, save_seconds);
if strcmp(wanted, filename)
    return;
end
filename = wanted;
sdat = load(filename);
max_time_msec = file_end_msec(col, save_seconds);
if isempty(sdat)
    % No spikes in this file: an all-zero matrix of the nominal size.
    S = sparse(num_neurons, max_time_msec);
else
    S = sparse(sdat(:,2)+1, sdat(:,1)+1, 1, ...
        max(num_neurons, max(sdat(:,2))+1), ...
        max(max_time_msec, max(sdat(:,1))+1));
end
end


function M = normalize_columns(M)
% Scale each column of M to unit Euclidean length; all-zero columns stay zero.
norms = sqrt(sum(M.^2, 1));
nonzero = norms > 0;
M(:, nonzero) = bsxfun(@rdivide, M(:, nonzero), norms(nonzero));
end


% Display how each cell responds to each pattern.

% figure(100+k); k= k + 1;
% disp('Mean firing rate of each cell grouped by patterns');
% cell = 1;
% for i = 1:rows
%     for j = 1:cols
% 
%         subplot(rows,cols,cell);
%         temp = zeros(1,num_patterns);
%         for pat = 1:num_patterns
%             temp(pat)=mean(mfr(pat).spikes(cell),2)*1000/(phase_end_msec-phase_start_msec);
%         end
%         bar(temp);
%         axis off;
%         ylim([0 20]);
%         xlim([1 num_patterns]);
%         %imagesc(transpose(reshape( mean(mfr(i).spikes,2)*1000/(phase_end_msec-phase_start_msec), rows, cols )))
% 
%         cell = cell + 1;
%     end
% end

